import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
import math
from gym.envs.registration import register

class StochasticCustomPendulumEnv(gym.Env):
    """
    Pendulum swing-up environment with optional Gaussian noise and a shifted,
    strictly positive reward.

    Dynamics and cost follow the classic pendulum task. Gaussian noise can be
    injected at three points, each controlled by a standard deviation passed
    to ``__init__`` (0 disables the perturbation):

    * ``action_noise_scale``   -- added to the commanded torque before clipping;
    * ``dynamics_noise_scale`` -- added to the angular acceleration;
    * ``obs_noise_scale``      -- added to every returned observation.

    The reward is the classic ``-cost`` shifted up by ``reward_offset`` and
    clamped below at ``min_reward``.

    Uses the older gym API: ``step`` returns ``(obs, reward, done, info)`` and
    ``reset`` returns only the observation.
    """

    metadata = {
        "render.modes": ["human"],
        "render_fps": 30
    }

    def __init__(
        self,
        g=10.0,
        action_noise_scale=0.0,
        dynamics_noise_scale=0.0,
        obs_noise_scale=0.0
    ):
        """
        Args:
            g: gravitational acceleration used in the equation of motion.
            action_noise_scale: std. dev. of Gaussian noise added to the torque.
            dynamics_noise_scale: std. dev. of Gaussian noise added to the
                angular acceleration.
            obs_noise_scale: std. dev. of Gaussian noise added to observations.

        Raises:
            ValueError: if any noise scale is negative.
        """
        super().__init__()
        # Fail fast: a negative std. dev. would otherwise make ``normal`` raise
        # in the middle of ``step`` (action noise) or be silently ignored by the
        # ``> 0.0`` guards (dynamics / observation noise).
        for name, scale in (
            ("action_noise_scale", action_noise_scale),
            ("dynamics_noise_scale", dynamics_noise_scale),
            ("obs_noise_scale", obs_noise_scale),
        ):
            if scale < 0.0:
                raise ValueError(f"{name} must be non-negative, got {scale!r}")

        self.max_speed = 8
        self.max_torque = 2.0
        self.dt = 0.05
        self.g = g
        self.m = 1.0   # mass
        self.l = 1.0   # length of the pendulum

        # Noise parameters (standard deviations)
        self.action_noise_scale = action_noise_scale
        self.dynamics_noise_scale = dynamics_noise_scale
        self.obs_noise_scale = obs_noise_scale

        # For reward shifting
        self.reward_offset = 20.0   # SHIFT to ensure positivity
        self.min_reward = 0.01      # lower clamp, keeps reward strictly above zero

        # Observation is [cos(theta), sin(theta), theta_dot].
        high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
        self.action_space = spaces.Box(
            low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=-high, high=high, dtype=np.float32
        )

        self.seed()
        self.viewer = None
        self.state = None  # (theta, theta_dot)
        self.reset()

    def seed(self, seed=None):
        """Re-create the environment RNG from ``seed`` and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        """
        Advance the simulation by one time step of length ``self.dt``.

        Args:
            action: commanded torque; a scalar or an array-like whose first
                element is the torque (e.g. a shape-(1,) ``action_space`` sample).

        Returns:
            ``(obs, reward, done, info)``. ``obs`` is
            ``[cos(theta), sin(theta), theta_dot]`` (plus observation noise),
            ``reward`` is a Python float, ``done`` is always ``False`` and
            ``info`` is an empty dict.
        """
        # 1) Perturb the commanded torque, then clip to the actuator limits.
        # The draw happens on every step; a zero scale adds no perturbation.
        raw_torque = float(np.asarray(action, dtype=np.float64).reshape(-1)[0])
        noisy_torque = raw_torque + self.np_random.normal(0.0, self.action_noise_scale)
        u = np.clip(noisy_torque, -self.max_torque, self.max_torque)

        # Current state
        th, thdot = self.state

        # Cost function from classic pendulum, evaluated on the pre-step state:
        # angle^2 + 0.1*theta_dot^2 + 0.001*(torque^2)
        costs = self.angle_normalize(th)**2 + 0.1 * (thdot**2) + 0.001 * (u**2)

        # Equation of motion:
        # thdd = -3*g/(2*l)*sin(th + pi) + 3/(m*l^2)*u
        accel = (
            (3.0 / (self.m * self.l**2)) * u
            - (3.0 * self.g / (2.0 * self.l)) * np.sin(th + np.pi)
        )

        # 2) Add noise to the angular acceleration
        if self.dynamics_noise_scale > 0.0:
            accel += self.np_random.normal(0.0, self.dynamics_noise_scale)

        # Semi-implicit Euler: update velocity first, clip it, then integrate
        # the angle with the clipped velocity.
        newthdot = thdot + accel * self.dt
        newthdot = np.clip(newthdot, -self.max_speed, self.max_speed)
        newth = th + newthdot * self.dt

        self.state = np.array([newth, newthdot], dtype=np.float32)

        # No terminal condition for pendulum (episode length is left to a
        # time-limit wrapper).
        done = False

        # 3) Shift the non-positive classic reward (-costs) up by reward_offset
        # and clamp at min_reward.
        # Bound: angle_normalize(th) lies in [-pi, pi), |thdot| <= max_speed
        # (reset draws |thdot| <= 1 and step clips it) and |u| <= max_torque,
        # so costs <= pi**2 + 0.1 * 8**2 + 0.001 * 2**2 ~= 16.27. With the
        # default offset of 20 the reward therefore stays >= ~3.73, and the
        # clamp only takes effect if the offset or the limits are changed.
        original_reward = -costs
        shifted_reward = original_reward + self.reward_offset
        # float() gives a consistent Python float no matter which side max picks.
        final_reward = float(max(self.min_reward, shifted_reward))

        # 4) Add observation noise if requested.
        # NOTE(review): unbounded Gaussian noise can push obs outside
        # observation_space (e.g. cos(theta) > 1); confirm consumers tolerate it.
        obs = self._get_obs()
        if self.obs_noise_scale > 0.0:
            obs += self.np_random.normal(0.0, self.obs_noise_scale, size=obs.shape).astype(np.float32)

        return obs, final_reward, done, {}

    def reset(self, *, seed=None, options=None):
        """
        Sample a new initial state and return its (optionally noisy) observation.

        The angle is drawn uniformly from [-pi, pi] and the angular velocity
        uniformly from [-1, 1].

        Args:
            seed: if not ``None``, re-seed the RNG via ``self.seed`` first.
            options: accepted for API compatibility; ignored.

        Returns:
            The observation ``[cos(theta), sin(theta), theta_dot]``.
        """
        if seed is not None:
            self.seed(seed)

        high = np.array([np.pi, 1], dtype=np.float32)
        low = np.array([-np.pi, -1], dtype=np.float32)
        self.state = self.np_random.uniform(low, high).astype(np.float32)

        # Possibly add noise to the returned observation on reset
        obs = self._get_obs()
        if self.obs_noise_scale > 0.0:
            obs += self.np_random.normal(0.0, self.obs_noise_scale, size=obs.shape).astype(np.float32)

        return obs

    def _get_obs(self):
        """Return the noise-free observation ``[cos(th), sin(th), thdot]`` as float32."""
        th, thdot = self.state
        return np.array([np.cos(th), np.sin(th), thdot], dtype=np.float32)

    def render(self, mode="human"):
        """No-op: rendering is not implemented for this environment."""
        pass

    def close(self):
        """Close and drop the viewer if one has been set."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

    @staticmethod
    def angle_normalize(x):
        """Wrap angle ``x`` (radians) into the half-open interval [-pi, pi)."""
        return ((x + np.pi) % (2 * np.pi)) - np.pi


# Register the environment so it can be created with
# ``gym.make("StochasticPendulum-v0")``; episodes are capped at 200 steps.
# ``entry_point`` hard-codes the module name ``Continuous_Pendulum``.
# NOTE(review): keep it in sync if this file is renamed or moved into a package.
#
# This call runs again whenever the module is executed again. That happens,
# for example, when the file is run as a script and gym then imports it via
# ``entry_point``, or when it is reloaded. Older gym releases raise
# ``gym.error.Error`` ("Cannot re-register id") in that case. The duplicate is
# this same spec, so the error is tolerated. This assumes no *other* env uses
# this id.
try:
    register(
        id="StochasticPendulum-v0",
        entry_point="Continuous_Pendulum:StochasticCustomPendulumEnv",
        max_episode_steps=200,
    )
except gym.error.Error:
    pass
